{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tune with MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Follow steps below to get started with a jupyter notebook for how to train a Towhee operator. This example fine-tunes a pretrained model (eg. resnet-18) with the MNIST dataset.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Setup Operator\n",
    "Create timm operator and load model by name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings #\n",
    "warnings.filterwarnings(\"ignore\") #\n",
    "import towhee\n",
    "# Set num_classes=10 for MNIST dataset\n",
    "op = towhee.ops.image_embedding.timm(model_name='resnet18', num_classes=10).get_op()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Configure Trainer:\n",
    "Modify training configurations on top of default values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from towhee.trainer.training_config import TrainingConfig\n",
    "\n",
    "training_config = TrainingConfig(\n",
    "    batch_size=64,\n",
    "    epoch_num=2,\n",
    "    output_dir='mnist_output'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Prepare Dataset\n",
    "The example here uses the MNIST dataset for both training and evaluation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0047ca842b0f42b0aeb5a84264addec0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/9912422 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "235637a1b36d43c29f34815faf846f79",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/28881 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cfe0568c7dd44233a636f0aa13a8bc38",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1648877 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n",
      "\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0eb94aa9f2c4228b9ec99d3f57b8b2a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4542 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from torchvision import transforms\n",
    "from towhee import dataset\n",
    "from torchvision.transforms import Lambda\n",
    "mean = 0.1307\n",
    "std = 0.3081\n",
    "mnist_transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                          Lambda(lambda x: x.repeat(3, 1, 1)),\n",
    "                                          transforms.Normalize(mean=[mean] * 3, std=[std] * 3)])\n",
    "train_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=True)\n",
    "eval_data = dataset('mnist', transform=mnist_transform, download=True, root='data', train=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Start Training\n",
    "Start to train mnist, it will take about 30-100 minutes on a cpu machine. If you train on a gpu machine, it will be much faster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-14 16:44:06,029 - 140010757764928 - trainer.py-trainer:319 - WARNING: TrainingConfig(output_dir='mnist_output', overwrite_output_dir=True, eval_strategy='epoch', eval_steps=None, batch_size=64, val_batch_size=-1, seed=42, epoch_num=2, dataloader_pin_memory=True, dataloader_drop_last=True, dataloader_num_workers=0, lr=5e-05, metric='Accuracy', print_steps=None, load_best_model_at_end=False, early_stopping={'monitor': 'eval_epoch_metric', 'patience': 4, 'mode': 'max'}, model_checkpoint={'every_n_epoch': 1}, tensorboard={'log_dir': None, 'comment': ''}, loss='CrossEntropyLoss', optimizer='Adam', lr_scheduler_type='linear', warmup_ratio=0.0, warmup_steps=0, device_str=None, sync_bn=False, freeze_bn=False)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04688367570442dcafe1af9bfb9857a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/937 [00:00<?, ?step/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e0ca5e8c11d0423592ce7b6a6ac3fb14",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/937 [00:00<?, ?step/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "op.train(training_config, train_dataset=train_data, eval_dataset=eval_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Observing epoch progress bars, if loss decreases while metric increases, then you are training the model properly.\n",
    "# 5. Predict after Training\n",
    "After training, you can make new predictions with the operator. Comparing new predicted results with actual labels, you can evaluate the fine-tuned model with accuracy.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAANd0lEQVR4nO3db6hc9Z3H8c9nbYVogjEGQ2LC2q0RuqyaaPwDKUvWpMVNhBixiz5YXDYhfVAhxX2wkhUrxIWwbCvog8KNhmaXbmpBxRiWbSWEzRaxeBOMxmarWc22t7kkmDyoESSr+e6De1Ku8c5vrnPOzJnk+37BMDPne8+cL0M+OWfmN+f8HBECcPH7o7YbADAYhB1IgrADSRB2IAnCDiTxpUFuzDZf/QN9FhGeanmtPbvtu2z/2vYR24/UeS0A/eVex9ltXyLpHUnfkDQm6XVJD0TErwrrsGcH+qwfe/bbJB2JiPci4oykn0haW+P1APRRnbBfI+m3k56PVcs+w/ZG26O2R2tsC0BNdb6gm+pQ4XOH6RExImlE4jAeaFOdPfuYpEWTni+UdKxeOwD6pU7YX5e02PZXbF8q6X5Ju5ppC0DTej6Mj4hPbD8k6WeSLpG0PSLebqwzAI3qeeitp43xmR3ou778qAbAhYOwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQGOmUz+mPGjBkda3Pnzi2uO3PmzGJ9w4YNxfrixYuL9fnz53es7du3r7jusWPlOUeefvrpYv3MmTPFejbs2YEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcbZh8Add9xRrK9Zs6ZYv/POOzvWbr/99uK69pQTfv5BP2f5vfnmm2tt+4orrijWH3vssS/c08WsVthtH5X0oaRPJX0SEcuaaApA85rYs/9FRHzQwOsA6CM+swNJ1A17SPq57f22N071B7Y32h61PVpzWwBqqHsYvzwijtm+WtIrtv87Ij5zdkNEjEgakSTb/fu2B0BRrT17RByr7k9IelHSbU00BaB5PYfd9uW2Z517LOmbkg411RiAZtU5jJ8n6cVqnPZLkv4tIv6jka4uMIsWLSrWn3nmmWJ91apVxXo/x7ovZNdff33bLVxQeg57RLwn6aYGewHQRwy9AUkQdiAJwg4kQdiBJAg7kASnuDag2ymqK1eu7Ov233nnnY61BQsWFNc9cOBAsb5ly5aeejpn9erVHWsPP/xwrdfudhnr0mWyT58+XWvbFyL27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsDeh2SeSPPvqoWN+1a1exvn///mL9yJEjHWu7d+8urttve/fu7Vh76qmniusePXq0WF+yZEmxPnv27I41xtkBXLQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJD/IyxRfrjDALFy4s1q+66qpi/eDBg022c8Hodgnu999/v1g/efJksX7DDTd0rJ04caK47oUsIqach5s9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwfnsDRgbG6tVz+r++++vtX7pXHnp4h5L70XXPbvt7bZP2D40adkc26/Yfre6v7K/bQKoazqH8T+SdNd5yx6RtCciFkvaUz0HMMS6hj0i9kk6dd7itZJ2VI93SLqn2bYANK3Xz+zzImJckiJi3PbVnf7Q9kZJG3vcDoCG9P0LuogYkTQiXbwnwgAXgl6H3o7bni9J1T1fewJDrtew75L0YPX4QUkvNdMOgH7pehhve6ekFZLm2h6T9D1JWyX91PZ6Sb+R9K1+NokL1y233NKx9uijjw6wE3QNe0Q80KG0suFeAPQRP5cFkiDsQBKEHUiCsANJEHYgCS4ljVpmzJhRrJemk16wYEFx3W6Xkl6+fHmxPj4+XqxfrLiUNJAcYQeSIOxAEoQdSIKwA0kQdiAJwg4kwaWkUcuGDRuK9Xnz5nWsnT17trjuyMhIsZ51HL1X7NmBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnOZ0fRvffeW6xv3769WJ85c2bH2quvvlpcd926dcX6yZMni/WsOJ8dSI6wA0kQdiAJwg4kQdiBJAg7kARhB5JgnD250ji4JO3bt69Yv/HGG3ve9uzZs4v106dP9/zamfU8zm57u+0Ttg9NWva47d/ZfqO6rW6yWQDNm85h/I8k3TXF8icjYkl1+/dm2wLQtK5hj4h9kk4NoBcAfVTnC7qHbL9ZHeZf2emPbG+0PWp7tMa2ANTUa9h/KOmrkpZIGpf0/U5/GBEjEbEsIpb1uC0ADegp7BFxPCI+jYizkrZJuq3ZtgA0raew254/6ek6SYc6/S2A4dD1uvG2d0paIWmu7TFJ35O0wvYSSSHpqKRv969F1HH33XcX6y+//HKx3u3a7t1s2bKlY41x9MHqGvaIeGCKxc/2oRcAfcTPZYEkCDuQBGEHkiDsQBKEHUiCU1wvAqXLPe/YsaO47mWXXVasd/v38dprrxXrq1at6lj7+OOPi+uiN1xKGkiOsANJEHYgCcIOJEHYgSQIO5AEYQeS6HrWG9pXZ9rkGTNm1Np2t3H0zZs3F+uMpQ8P9uxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kATnsw+Bbpd73rlzZ7FeZyz9iSeeKNa3bt1arDOOPnw4nx1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcfQiMjo4W60uXLu35tQ8cOFCs33rrrcX6pZde2vO2JenMmTMda7Nnzy6uO2vWrFrbXr9+fcfaihUriutu2rSpWD948GAvLQ1Ez+PsthfZ3mv7sO23bW+qls+x/Yrtd6v7K5tuGkBzpnMY/4mkv4uIr0m6Q9J3bP+ppEck7YmIxZL2VM8BDKmuYY+I8Yg4UD3+UNJhSddIWivp3NxCOyTd06ceATTgC12Dzva1kpZK+qWkeRExLk38h2D76g7rbJS0sWafAGqadthtz5T0vKTvRsTv7Sm/A/iciBiRNFK9Bl/QAS2Z1tCb7S9rIug/jogXqsXHbc+v6vMlnehPiwCa0HXP7old+LOSDkfEDyaVdkl6UNLW6v6lvnSYQLehtTrDo91e+7nnnivW58yZ0/O2JenUqVMdazfddFNx3euuu67Wtku6HZkuXry4WB/mobdOpnMYv1zSX0t6y/Yb1bLNmgj5T22vl/QbSd/qS4cAGtE17BHxC0md/htc2Ww7APqFn8sCSRB2IAnCDiRB2IEkCDuQBFM2J3ffffcV6/08BbrbWPcgT7/OgD07kARhB5Ig7EAShB1IgrADSRB2IAnCDiTBOPsQePLJJ4v1NWvWFOvdzr3Oatu2bR1rx48fL667e/fupttpHXt2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCKZuBi0zPUzYDuDgQdiAJwg4kQdiBJAg7kARhB5Ig7EASXcNue5HtvbYP237b9qZq+eO2f2f7jeq2uv/tAuhV1x/V2J4vaX5EHLA9S9J+SfdI+itJpyPin6e9MX5UA/Rdpx/VTGd+9nFJ49XjD20flnRNs+0B6Lcv9Jnd9rWSlkr6ZbXoIdtv2t5u+8oO62y0PWp7tF6rAOqY9m/jbc+U9J+S/jEiXrA9T9IHkkLSFk0c6v9tl9fgMB7os06H8dMKu+0vS9ot6WcR8YMp6tdK2h0Rf9bldQg70Gc9nwjjiak2n5V0eHLQqy/uzlkn6VDdJgH0z3S+jf+6pP+S9Jaks9XizZIekLREE4fxRyV9u/oyr/Ra7NmBPqt1GN8Uwg70H+ezA8kRdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkuh6wcmGfSDpfyc9n1stG0bD2tuw9iXRW6+a7O2POxUGej775zZuj0bEstYaKBjW3oa1L4neejWo3jiMB5Ig7EASbYd9pOXtlwxrb8Pal0RvvRpIb61+ZgcwOG3v2QEMCGEHkmgl7Lbvsv1r20dsP9JGD53YPmr7rWoa6lbnp6vm0Dth+9CkZXNsv2L73ep+yjn2WuptKKbxLkwz3up71/b05wP/zG77EknvSPqGpDFJr0t6ICJ+NdBGOrB9VNKyiGj9Bxi2/1zSaUn/cm5qLdv/JOlURGyt/qO8MiL+fkh6e1xfcBrvPvXWaZrxv1GL712T05/3oo09+22SjkTEexFxRtJPJK1toY+hFxH7JJ06b/FaSTuqxzs08Y9l4Dr0NhQiYjwiDlSPP5R0bprxVt+7Ql8D0UbYr5H020nPxzRc872HpJ/b3m97Y9vNTGHeuWm2qvurW+7nfF2n8R6k86YZH5r3rpfpz+tqI+xTTU0zTON/yyPiZkl/Kek71eEqpueHkr6qiTkAxyV9v81mqmnGn5f03Yj4fZu9TDZFXwN539oI+5ikRZOeL5R0rIU+phQRx6r7E5Je1MTHjmFy/NwMutX9iZb7+YOIOB4Rn0bEWUnb1OJ7V00z/rykH0fEC9Xi1t+7qfoa1PvWRthfl7TY9ldsXyrpfkm7Wujjc2xfXn1xItuXS/qmhm8q6l2SHqwePyjppRZ7+Yxhmca70zTjavm9a33684gY+E3Sak18I/8/kv6hjR469PUnkg5Wt7fb7k3STk0c1v2fJo6I1ku6StIeSe9W93OGqLd/1cTU3m9qIljzW+rt65r4aPimpDeq2+q237tCXwN53/i5LJAEv6ADkiDsQBKEHUiCsANJEHYgCcIOJEHYgST+HxGKPl6FUpniAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "this picture is number 2\n"
     ]
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import random\n",
    "\n",
    "# get random picture and predict it.\n",
    "img_index = random.randint(0, len(eval_data))\n",
    "img = eval_data.dataset[img_index][0]\n",
    "img = img.numpy().transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)\n",
    "pil_img = img * std + mean\n",
    "plt.imshow(pil_img)\n",
    "plt.show()\n",
    "test_img = eval_data.dataset[img_index][0].unsqueeze(0).to(op.trainer.configs.device)\n",
    "out = op.trainer.predict(test_img)\n",
    "predict_num = torch.argmax(torch.softmax(out, dim=-1)).item()\n",
    "print('this picture is number {}'.format(predict_num))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}